{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Overview: \n",
    "This code implements one of the multiple ways of multi-model RAG. It extracts and processes text and images from PDFs, utilizing a multi-modal Retrieval-Augmented Generation (RAG) system for summarizing and retrieving content for question answering.\n",
    "\n",
    "### Key Components:\n",
    "   - **PyMuPDF**: For extracting text and images from PDFs.\n",
    "   - **Gemini 1.5-flash model**: To summarize images and tables.\n",
    "   - **Cohere Embeddings**: For embedding document splits.\n",
    "   - **Chroma Vectorstore**: To store and retrieve document embeddings.\n",
    "   - **LangChain**: To orchestrate the retrieval and generation pipeline."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Diagram:\n",
    "   <img src=\"../images/multi_model_rag_with_captioning.svg\" alt=\"Reliable-RAG\" width=\"300\">\n",
    "\n",
    "### Motivation: \n",
    "Efficiently summarize complex documents to facilitate easy retrieval and concise responses for multi-modal data.\n",
    "\n",
    "### Method Details:\n",
    "   - Text and images are extracted from the PDF using PyMuPDF.\n",
    "   - Summarization is performed on extracted images and tables using Gemini.\n",
    "   - Embeddings are generated via Cohere for storage in Chroma.\n",
    "   - A similarity-based retriever fetches relevant sections based on the query.\n",
    "\n",
    "### Benefits:\n",
    "   - Simplified retrieval from complex, multi-modal documents.\n",
    "   - Streamlined Q&A process for both text and images.\n",
    "   - Flexible architecture for expanding to more document types.\n",
    "\n",
    "### Implementation:\n",
    "   - Documents are split into chunks with overlap using a text splitter.\n",
    "   - Summarized text and image content are stored as vectors.\n",
    "   - Queries are handled by retrieving relevant document segments and generating concise answers.\n",
    "\n",
    "### Summary: \n",
    "The project enables multi-modal document processing and retrieval, providing concise, relevant responses by combining state-of-the-art LLMs and vector-based retrieval systems."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "True"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import fitz  # PyMuPDF\n",
    "from PIL import Image\n",
    "import io\n",
    "import os\n",
    "from dotenv import load_dotenv\n",
    "\n",
    "import google.generativeai as genai\n",
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "from langchain_core.documents import Document\n",
    "from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
    "from langchain_community.vectorstores import Chroma\n",
    "from langchain_cohere import ChatCohere, CohereEmbeddings\n",
    "\n",
    "load_dotenv()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Download the \"Attention is all you need\" paper"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2024-09-20 19:19:26--  https://arxiv.org/pdf/1706.03762\n",
      "Resolving arxiv.org (arxiv.org)... 151.101.195.42, 151.101.3.42, 151.101.67.42, ...\n",
      "Connecting to arxiv.org (arxiv.org)|151.101.195.42|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 2215244 (2.1M) [application/pdf]\n",
      "Saving to: ‘1706.03762’\n",
      "\n",
      "1706.03762          100%[===================>]   2.11M  13.3MB/s    in 0.2s    \n",
      "\n",
      "2024-09-20 19:19:26 (13.3 MB/s) - ‘1706.03762’ saved [2215244/2215244]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "!wget https://arxiv.org/pdf/1706.03762\n",
    "!mv 1706.03762 attention_is_all_you_need.pdf"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Data Extraction"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "text_data = []\n",
    "img_data = []"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "with fitz.open('attention_is_all_you_need.pdf') as pdf_file:\n",
    "    # Create a directory to store the images\n",
    "    if not os.path.exists(\"extracted_images\"):\n",
    "        os.makedirs(\"extracted_images\")\n",
    "\n",
    "    # Loop through every page in the PDF\n",
    "    for page_number in range(len(pdf_file)):\n",
    "        page = pdf_file[page_number]\n",
    "        \n",
    "        # Get the text on page\n",
    "        text = page.get_text().strip()\n",
    "        text_data.append({\"response\": text, \"name\": page_number+1})\n",
    "        # Get the list of images on the page\n",
    "        images = page.get_images(full=True)\n",
    "\n",
    "        # Loop through all images found on the page\n",
    "        for image_index, img in enumerate(images, start=0):\n",
    "            xref = img[0]  # Get the XREF of the image\n",
    "            base_image = pdf_file.extract_image(xref)  # Extract the image\n",
    "            image_bytes = base_image[\"image\"]  # Get the image bytes\n",
    "            image_ext = base_image[\"ext\"]  # Get the image extension\n",
    "            \n",
    "            # Load the image using PIL and save it\n",
    "            image = Image.open(io.BytesIO(image_bytes))\n",
    "            image.save(f\"extracted_images/image_{page_number+1}_{image_index+1}.{image_ext}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "genai.configure(api_key=os.getenv('GOOGLE_API_KEY'))\n",
    "model = genai.GenerativeModel(model_name=\"gemini-1.5-flash\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Image Captioning"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "for img in os.listdir(\"extracted_images\"):\n",
    "    image = Image.open(f\"extracted_images/{img}\")\n",
    "    response = model.generate_content([image, \"You are an assistant tasked with summarizing tables, images and text for retrieval. \\\n",
    "    These summaries will be embedded and used to retrieve the raw text or table elements \\\n",
    "    Give a concise summary of the table or text that is well optimized for retrieval. Table or text or image:\"])\n",
    "    img_data.append({\"response\": response.text, \"name\": img})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Vectostore"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Set embeddings\n",
    "embedding_model = CohereEmbeddings(model=\"embed-english-v3.0\")\n",
    "\n",
    "# Load the document\n",
    "docs_list = [Document(page_content=text['response'], metadata={\"name\": text['name']}) for text in text_data]\n",
    "img_list = [Document(page_content=img['response'], metadata={\"name\": img['name']}) for img in img_data]\n",
    "\n",
    "# Split\n",
    "text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(\n",
    "    chunk_size=400, chunk_overlap=50\n",
    ")\n",
    "\n",
    "doc_splits = text_splitter.split_documents(docs_list)\n",
    "img_splits = text_splitter.split_documents(img_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Add to vectorstore\n",
    "vectorstore = Chroma.from_documents(\n",
    "    documents=doc_splits + img_splits, # adding the both text and image splits\n",
    "    collection_name=\"multi_model_rag\",\n",
    "    embedding=embedding_model,\n",
    ")\n",
    "\n",
    "retriever = vectorstore.as_retriever(\n",
    "                search_type=\"similarity\",\n",
    "                search_kwargs={'k': 1}, # number of documents to retrieve\n",
    "            )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Query"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "query = \"What is the BLEU score of the Transformer (base model)?\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "docs = retriever.invoke(query)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Output"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "The Transformer (base model) achieves a BLEU score of 27.3.\n"
     ]
    }
   ],
   "source": [
    "from langchain_core.output_parsers import StrOutputParser\n",
    "\n",
    "# Prompt\n",
    "system = \"\"\"You are an assistant for question-answering tasks. Answer the question based upon your knowledge. \n",
    "Use three-to-five sentences maximum and keep the answer concise.\"\"\"\n",
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\"system\", system),\n",
    "        (\"human\", \"Retrieved documents: \\n\\n <docs>{documents}</docs> \\n\\n User question: <question>{question}</question>\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "# LLM\n",
    "llm = ChatCohere(model=\"command-r-plus\", temperature=0)\n",
    "\n",
    "# Chain\n",
    "rag_chain = prompt | llm | StrOutputParser()\n",
    "\n",
    "# Run\n",
    "generation = rag_chain.invoke({\"documents\":docs[0].page_content, \"question\": query})\n",
    "print(generation)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "test",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.0"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
